{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2fb43f21",
   "metadata": {
    "id": "2fb43f21"
   },
   "source": [
    "# Introduction\n",
    "\n",
     "The goal of this tutorial is to demonstrate the basic steps required to set up and train the Maxine Background Noise Removal network [1] in NeMo.\n",
    "\n",
    "This notebook covers the following steps:\n",
    "\n",
    "* Download speech and noise data\n",
    "* Prepare the training data by mixing speech and noise\n",
    "* Configure and train a simple single-output model\n",
    "\n",
    "Note that this tutorial is only for demonstration purposes.\n",
    "\n",
    "*Disclaimer:*\n",
     "The user is responsible for checking the content of datasets and the applicable licenses, and for determining whether they are suitable for the intended use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50525ece",
   "metadata": {
    "id": "50525ece"
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
     "You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
    "\"\"\"\n",
    "\n",
    "GIT_USER = 'NVIDIA'\n",
    "BRANCH = 'main'\n",
    "\n",
    "if 'google.colab' in str(get_ipython()):\n",
    "\n",
    "    # Install dependencies\n",
    "    !pip install wget\n",
    "    !apt-get install sox libsndfile1 ffmpeg\n",
    "    !pip install text-unidecode\n",
     "    !pip install 'matplotlib>=3.3.2'\n",
    "\n",
    "    ## Install NeMo\n",
    "    !python -m pip install git+https://github.com/{GIT_USER}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]\n",
    "\n",
    "    ## Install TorchAudio\n",
     "    !pip install 'torchaudio>=0.13.0' -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "765e8b8a",
   "metadata": {
    "id": "765e8b8a"
   },
   "source": [
    "The following cell will take care of the necessary imports and prepare utility functions used throughout the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9248dd5f",
   "metadata": {
    "id": "9248dd5f"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import librosa\n",
    "import os\n",
    "import torch\n",
    "import tqdm\n",
    "\n",
    "import IPython.display as ipd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import lightning.pytorch as pl\n",
    "import soundfile as sf\n",
    "\n",
    "from omegaconf import OmegaConf, open_dict\n",
    "from pathlib import Path\n",
    "from torchmetrics.functional.audio import signal_distortion_ratio, scale_invariant_signal_distortion_ratio\n",
    "\n",
    "from nemo.utils.notebook_utils import download_an4\n",
    "from nemo.collections.asr.parts.preprocessing.segment import AudioSegment\n",
    "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
    "\n",
    "\n",
    "# Utility functions for displaying signals and metrics\n",
    "def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
    "    \"\"\"Show the time-domain signal and its spectrogram.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 2.5))\n",
    "\n",
    "    # show waveform\n",
    "    t = np.arange(0, len(signal)) / sample_rate\n",
    "\n",
    "    ax[0].plot(t, signal)\n",
    "    ax[0].set_xlim(0, t.max())\n",
    "    ax[0].grid()\n",
    "    ax[0].set_xlabel('time / s')\n",
    "    ax[0].set_ylabel('amplitude')\n",
    "    ax[0].set_title(tag)\n",
    "\n",
    "    n_fft = 1024\n",
    "    hop_length = 256\n",
    "\n",
    "    D = librosa.amplitude_to_db(np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)), ref=np.max)\n",
    "    img = librosa.display.specshow(D, y_axis='linear', x_axis='time', sr=sample_rate, n_fft=n_fft, hop_length=hop_length, ax=ax[1])\n",
    "    ax[1].set_title(tag)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.colorbar(img, format=\"%+2.f dB\", ax=ax)\n",
    "\n",
    "def show_metrics(signal: np.ndarray, reference: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
    "    \"\"\"Show metrics for the time-domain signal and the reference signal.\n",
    "    \"\"\"\n",
    "    sdr = signal_distortion_ratio(preds=torch.tensor(signal), target=torch.tensor(reference))\n",
    "    sisdr = scale_invariant_signal_distortion_ratio(preds=torch.tensor(signal), target=torch.tensor(reference))\n",
    "    print(tag)\n",
    "    print('\\tsdr:  ', sdr.item())\n",
    "    print('\\tsisdr:', sisdr.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfa2199e",
   "metadata": {
    "id": "bfa2199e"
   },
   "source": [
    "### Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aa05a36",
   "metadata": {
    "id": "0aa05a36"
   },
   "source": [
     "In this notebook, it is assumed that all audio will be resampled to 16 kHz. The data and configuration will be stored under `root_dir` as defined below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e790630",
   "metadata": {
    "id": "8e790630"
   },
   "outputs": [],
   "source": [
    "# sample rate used throughout the notebook\n",
    "sample_rate = 16000\n",
    "\n",
    "# root directory for data preparation, configurations, etc\n",
    "root_dir = Path('./')\n",
    "\n",
    "# data directory\n",
    "data_dir = root_dir / 'data'\n",
    "data_dir.mkdir(exist_ok=True)\n",
    "\n",
    "# scripts directory\n",
    "scripts_dir = root_dir / 'scripts'\n",
    "scripts_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af717ad",
   "metadata": {
    "id": "9af717ad"
   },
   "source": [
     "Clean speech data is used to prepare the datasets for training a simple speech enhancement model.\n",
     "\n",
     "In this tutorial, a subset of the LibriSpeech dataset [2] will be downloaded and used as the speech material.\n",
    "The following cell will download and prepare the speech data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c5a60a",
   "metadata": {
    "id": "e1c5a60a"
   },
   "outputs": [],
   "source": [
    "speech_dir = data_dir / 'speech'\n",
    "speech_data_set = 'mini'\n",
    "\n",
    "# Copy script\n",
    "get_librispeech_script = os.path.join(scripts_dir, 'get_librispeech_data.py')\n",
    "if not os.path.exists(get_librispeech_script):\n",
    "    !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/get_librispeech_data.py\n",
    "\n",
    "# Download the data\n",
    "if not speech_dir.is_dir():\n",
    "    speech_dir.mkdir(exist_ok=True)\n",
    "    !python {get_librispeech_script} --data_root={speech_dir} --data_set={speech_data_set}\n",
    "else:\n",
    "    print('Speech dataset already exists in:', speech_dir)\n",
    "\n",
     "# Reduce the dataset size for this tutorial: 1000 clean utterances for training and 100 clean utterances for testing\n",
    "train_metadata = read_manifest(speech_dir / 'train_clean_5.json')\n",
    "write_manifest(speech_dir / 'train.json', train_metadata[:1000])\n",
    "\n",
    "test_metadata = read_manifest(speech_dir / 'dev_clean_2.json')\n",
    "write_manifest(speech_dir / 'test.json', test_metadata[:100])\n",
    "\n",
    "# Speech manifests\n",
    "speech_manifest = {\n",
    "    'train': speech_dir / 'train.json',\n",
    "    'test': speech_dir / 'test.json',\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06f07374",
   "metadata": {
    "id": "06f07374"
   },
   "source": [
    "Noise data will be mixed with the downloaded speech data to prepare a noisy dataset.\n",
    "\n",
     "A subset of the DEMAND dataset [3] will be downloaded and used as the noise material. The following cell will download and prepare the noise data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1de1089",
   "metadata": {
    "id": "b1de1089"
   },
   "outputs": [],
   "source": [
    "noise_dir = data_dir / 'noise'\n",
    "noise_data_set = 'STRAFFIC,PSTATION'\n",
    "\n",
    "# Copy script\n",
    "get_demand_script = os.path.join(scripts_dir, 'get_demand_data.py')\n",
    "if not os.path.exists(get_demand_script):\n",
    "    !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/get_demand_data.py\n",
    "\n",
    "if not noise_dir.is_dir():\n",
    "    noise_dir.mkdir(exist_ok=True)\n",
    "    !python {get_demand_script} --data_root={noise_dir} --data_sets={noise_data_set}\n",
    "else:\n",
    "    print('Noise directory already exists in:', noise_dir)\n",
    "\n",
    "\n",
    "def create_noise_manifest(base_dir, subset, offset=0, duration=None):\n",
     "    \"\"\"Create a noise manifest for the given subset, selecting a segment\n",
     "    (offset and duration, in seconds) of each noise recording.\n",
     "    \"\"\"\n",
    "    complete_noise_manifests = glob.glob(str(base_dir / 'manifests' / '*.json'))\n",
    "    subset_noise_manifest = base_dir / f'{subset}_manifest.json'\n",
    "    \n",
    "    subset_metadata = []\n",
    "\n",
    "    for noise_manifest in complete_noise_manifests:\n",
    "        complete_metadata = read_manifest(noise_manifest)\n",
    "    \n",
    "        for item in complete_metadata:\n",
    "            new_item = item.copy()\n",
    "            new_item['offset'] = offset\n",
    "            new_item['duration'] = duration\n",
    "            subset_metadata.append(new_item)\n",
    "\n",
    "    write_manifest(subset_noise_manifest.as_posix(), subset_metadata)\n",
    "\n",
    "    return subset_noise_manifest\n",
    "\n",
    "noise_manifest = {\n",
    "    'train': create_noise_manifest(noise_dir, 'train', offset=0, duration=200),\n",
    "    'test': create_noise_manifest(noise_dir, 'test', offset=200, duration=100),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1378d78a",
   "metadata": {
    "id": "1378d78a"
   },
   "source": [
    "For this tutorial, a single-channel noisy dataset is constructed by adding speech and noise.\n",
    "\n",
    "The following block will add speech and noise and save the noisy data. The noisy data is created by mixing speech and noise at a few pre-defined signal-to-noise ratios (SNRs). Note that a separate manifest will be created for each SNR."
   ]
  },
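  {
   "cell_type": "markdown",
   "id": "snr-mixing-sketch",
   "metadata": {},
   "source": [
    "Conceptually, mixing at a target SNR scales the noise so that the ratio of speech power to scaled-noise power matches the desired value. A minimal sketch of this idea (an illustration only, not the actual `add_noise.py` implementation) is:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:\n",
    "    \"\"\"Mix speech and noise at the given SNR (in dB).\"\"\"\n",
    "    speech_power = np.mean(speech**2)\n",
    "    noise_power = np.mean(noise**2)\n",
    "    # Choose scale so that speech_power / (scale**2 * noise_power) == 10**(snr_db / 10)\n",
    "    scale = np.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))\n",
    "    return speech + scale * noise\n",
    "```"
   ]
  },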
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c678a0",
   "metadata": {
    "id": "d6c678a0"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Suppress output of this cell, since the script used below is relatively verbose.\n",
    "\n",
    "# Copy script\n",
    "add_noise_script = os.path.join(scripts_dir, 'add_noise.py')\n",
    "if not os.path.exists(add_noise_script):\n",
    "    !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/add_noise.py\n",
    "\n",
    "# Generate noisy datasets and save the noise component as well.\n",
    "noisy_dir = data_dir / 'noisy'\n",
    "noisy_dir.mkdir(exist_ok=True)\n",
    "\n",
    "for subset in ['train', 'test']:\n",
    "    noisy_subset_dir = noisy_dir / subset\n",
    "\n",
    "    if not noisy_subset_dir.is_dir():\n",
    "        noisy_subset_dir.mkdir(exist_ok=True)\n",
    "        !python {add_noise_script} --input_manifest={speech_manifest[subset]} --noise_manifest={noise_manifest[subset]} --out_dir={noisy_subset_dir} --snrs 0 5 10 15 20 --num_workers 4 --save_noise"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4cd1426",
   "metadata": {
    "id": "c4cd1426"
   },
   "source": [
    "Training a model requires an input dataset which includes information about the noisy input signal and the desired (target) output signal.\n",
    "\n",
    "In this tutorial, train and test manifests are created by combining the information from the speech manifests and each noisy manifest generated in the previous step. Note that the final manifests include `noisy_filepath`, `speech_filepath` and `noise_filepath`. These keys can be used to define the input signal and the output signal for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfeb639",
   "metadata": {
    "id": "fcfeb639"
   },
   "outputs": [],
   "source": [
    "dataset_manifest = {\n",
    "    'train': data_dir / 'dataset_train.json',\n",
    "    'test': data_dir / 'dataset_test.json',\n",
    "}\n",
    "\n",
    "for subset in ['train', 'test']:\n",
    "    # Load clean manifest\n",
    "    speech_metadata = read_manifest(speech_manifest[subset])\n",
    "\n",
    "    # Load noisy manifests\n",
    "    noisy_manifests = glob.glob(str(noisy_dir / subset / 'manifests/*.json'))\n",
    "    noisy_manifests.sort()\n",
    "\n",
    "    subset_metadata = []\n",
    "\n",
    "    for noisy_manifest in noisy_manifests:\n",
    "        noisy_metadata = read_manifest(noisy_manifest)\n",
    "\n",
    "        for speech_item, noisy_item in tqdm.tqdm(zip(speech_metadata, noisy_metadata), total=len(noisy_metadata)):\n",
    "            # Check that the file matches\n",
    "            assert os.path.basename(speech_item['audio_filepath']) == os.path.basename(noisy_item['audio_filepath']), f'Speech: {speech_item}. Noisy: {noisy_item}'\n",
    "\n",
    "            # Create a new item for the subset manifest\n",
    "            subset_item = {\n",
    "                'noisy_filepath': noisy_item['audio_filepath'],\n",
    "                'speech_filepath': speech_item['audio_filepath'],\n",
    "                'noise_filepath': noisy_item['noise_filepath'],\n",
    "                'duration': noisy_item['duration'],\n",
    "                'offset': noisy_item.get('offset', 0)\n",
    "            }\n",
    "\n",
    "            subset_metadata.append(subset_item)\n",
    "\n",
    "    # Save the subset manifest\n",
    "    write_manifest(dataset_manifest[subset].as_posix(), subset_metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa4248ab",
   "metadata": {
    "id": "aa4248ab"
   },
   "source": [
    "### Model configuration\n",
    "\n",
     "We use the SEASR model architecture [1] for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "002b917c",
   "metadata": {
    "id": "002b917c"
   },
   "outputs": [],
   "source": [
    "config_dir = root_dir / 'conf'\n",
    "config_dir.mkdir(exist_ok=True)\n",
    "\n",
    "config_path = config_dir / 'maxine_bnr.yaml'\n",
    "\n",
    "if not config_path.is_file():\n",
    "    !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/examples/audio/conf/maxine_bnr.yaml -P {config_dir.as_posix()}\n",
    "\n",
    "config = OmegaConf.load(config_path)\n",
    "config = OmegaConf.to_container(config, resolve=True)\n",
    "config = OmegaConf.create(config)\n",
    "\n",
    "print('Loaded config')\n",
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65cab3ad",
   "metadata": {
    "id": "65cab3ad"
   },
   "source": [
     "The training dataset is configured with the following parameters:\n",
    "* `manifest_filepath` points to a manifest file, with each line containing a dictionary corresponding to a single example\n",
    "* `input_key` is the key corresponding to the input audio signal in the example dictionary\n",
    "* `target_key` is the key corresponding to the desired output (target) audio signal in the example dictionary\n",
    "* `min_duration` can be used to filter out short examples"
   ]
  },
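  {
   "cell_type": "markdown",
   "id": "manifest-line-example",
   "metadata": {},
   "source": [
    "For illustration, each line of the manifest prepared above is a JSON dictionary along these lines (the paths shown are hypothetical examples, not actual outputs):\n",
    "\n",
    "```json\n",
    "{\"noisy_filepath\": \"data/noisy/train/noisy/example.wav\", \"speech_filepath\": \"data/speech/example.wav\", \"noise_filepath\": \"data/noisy/train/noise/example.wav\", \"duration\": 4.2, \"offset\": 0}\n",
    "```"
   ]
  },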
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91717b15",
   "metadata": {
    "id": "91717b15"
   },
   "outputs": [],
   "source": [
    "# Setup training dataset\n",
    "config.model.train_ds.manifest_filepath = dataset_manifest['train'].as_posix()\n",
    "config.model.train_ds.input_key = 'noisy_filepath'\n",
    "config.model.train_ds.target_key = 'speech_filepath'\n",
    "config.model.train_ds.min_duration = 0 # load all audio files, without filtering short ones\n",
    "\n",
    "print(\"Train dataset config:\")\n",
    "print(OmegaConf.to_yaml(config.model.train_ds))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c11592e4",
   "metadata": {
    "id": "c11592e4"
   },
   "source": [
    "Validation and test datasets can be configured in the same way as the training dataset. Here, we use the same dataset for validation and testing purposes for simplicity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8bc0a0a",
   "metadata": {
    "id": "d8bc0a0a"
   },
   "outputs": [],
   "source": [
    "# Use test manifest for validation and test sets\n",
    "config.model.validation_ds.manifest_filepath = dataset_manifest['test'].as_posix()\n",
    "config.model.validation_ds.input_key = 'noisy_filepath'\n",
    "config.model.validation_ds.target_key = 'speech_filepath'\n",
    "\n",
    "config.model.test_ds.manifest_filepath = dataset_manifest['test'].as_posix()\n",
    "config.model.test_ds.input_key = 'noisy_filepath'\n",
    "config.model.test_ds.target_key = 'speech_filepath'\n",
    "\n",
    "print(\"Validation dataset config:\")\n",
    "print(OmegaConf.to_yaml(config.model.validation_ds))\n",
    "\n",
    "print(\"Test dataset config:\")\n",
    "print(OmegaConf.to_yaml(config.model.test_ds))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bc0f8f3",
   "metadata": {
    "id": "1bc0f8f3"
   },
   "source": [
     "Metrics for the validation and test sets are configured in the following cell.\n",
     "\n",
     "In this tutorial, signal-to-distortion ratio (SDR) and scale-invariant SDR (SI-SDR) from TorchMetrics are used [4]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82de1dbb",
   "metadata": {
    "id": "82de1dbb"
   },
   "outputs": [],
   "source": [
    "# Setup metrics to compute on validation and test sets\n",
    "metrics = OmegaConf.create({\n",
    "    'sisdr': {\n",
    "        '_target_': 'torchmetrics.audio.ScaleInvariantSignalDistortionRatio',\n",
    "    },\n",
    "    'sdr': {\n",
    "        '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
    "    }\n",
    "})\n",
    "config.model.metrics.val = metrics\n",
    "config.model.metrics.test = metrics\n",
    "\n",
    "print(\"Metrics config:\")\n",
    "print(OmegaConf.to_yaml(config.model.metrics))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eda2d65",
   "metadata": {
    "id": "4eda2d65"
   },
   "source": [
    "### Trainer configuration\n",
    "NeMo models are primarily PyTorch Lightning modules and therefore are entirely compatible with the PyTorch Lightning ecosystem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e53ba6ef",
   "metadata": {
    "id": "e53ba6ef"
   },
   "outputs": [],
   "source": [
    "print(\"Trainer config:\")\n",
    "print(OmegaConf.to_yaml(config.trainer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9201689",
   "metadata": {
    "id": "a9201689"
   },
   "source": [
     "We can modify some trainer settings for this tutorial.\n",
     "Most importantly, the number of epochs is set to a small value to limit the runtime of this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17b7499f",
   "metadata": {
    "id": "17b7499f"
   },
   "outputs": [],
   "source": [
     "# Check if a GPU is available and use it\n",
    "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
    "config.trainer.devices = 1\n",
    "config.trainer.accelerator = accelerator\n",
    "\n",
     "# Reduce the maximum number of epochs for a quick demonstration\n",
    "config.trainer.max_epochs = 10\n",
    "\n",
    "# Remove distributed training flags\n",
    "config.trainer.strategy = 'auto'\n",
    "\n",
    "# Instantiate the trainer\n",
    "trainer = pl.Trainer(**config.trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "053192a4",
   "metadata": {
    "id": "053192a4"
   },
   "source": [
    "### Experiment manager\n",
    "\n",
    "NeMo has an experiment manager that handles logging and checkpointing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb8948d",
   "metadata": {
    "id": "ccb8948d"
   },
   "outputs": [],
   "source": [
    "from nemo.utils.exp_manager import exp_manager\n",
    "\n",
    "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n",
    "# The exp_dir provides a path to the current experiment for easy access\n",
    "\n",
    "print(\"Experiment directory:\")\n",
    "print(exp_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bb583f8",
   "metadata": {
    "id": "8bb583f8"
   },
   "source": [
    "### Model instantiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6905fea2",
   "metadata": {
    "id": "6905fea2"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.audio.models.maxine import BNR2",
    "\n",
    "enhancement_model = BNR2(cfg=config.model, trainer=trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ee1a6d2",
   "metadata": {
    "id": "6ee1a6d2"
   },
   "source": [
    "### Training\n",
     "Create a TensorBoard visualization to monitor training progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9e7639",
   "metadata": {
    "id": "dc9e7639"
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    from google import colab\n",
    "    COLAB_ENV = True\n",
    "except (ImportError, ModuleNotFoundError):\n",
    "    COLAB_ENV = False\n",
    "\n",
    "# Load the TensorBoard notebook extension\n",
    "if COLAB_ENV:\n",
    "    %load_ext tensorboard\n",
    "    %tensorboard --logdir {exp_dir}\n",
    "else:\n",
     "    print(\"To use TensorBoard, please run this notebook in a Google Colab environment.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e962b9b",
   "metadata": {
    "id": "5e962b9b"
   },
   "source": [
    "Training can be started using `trainer.fit`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b00b3f4",
   "metadata": {
    "id": "0b00b3f4"
   },
   "outputs": [],
   "source": [
    "trainer.fit(enhancement_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b5c1e78",
   "metadata": {
    "id": "5b5c1e78"
   },
   "source": [
     "After training is complete, the configured metrics can be computed on the test set as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5216195d",
   "metadata": {
    "id": "5216195d"
   },
   "outputs": [],
   "source": [
    "trainer.test(enhancement_model, ckpt_path=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf5b73e6",
   "metadata": {
    "id": "bf5b73e6"
   },
   "source": [
    "### Inference\n",
    "\n",
     "The following cell provides an example of inference on a single audio file.\n",
    "For simplicity, the audio file information is taken from the test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b5cf923",
   "metadata": {
    "id": "6b5cf923"
   },
   "outputs": [],
   "source": [
    "# Load a single audio example from the test set\n",
    "test_metadata = read_manifest(dataset_manifest['test'].as_posix())\n",
    "\n",
    "# Path to audio files\n",
    "noisy_filepath = test_metadata[-1]['noisy_filepath'] # noisy audio\n",
    "speech_filepath = test_metadata[-1]['speech_filepath'] # clean speech\n",
    "noise_filepath = test_metadata[-1]['noise_filepath'] # corresponding noise\n",
    "\n",
    "# Load audio\n",
    "noisy_signal = AudioSegment.from_file(noisy_filepath, target_sr=sample_rate).samples\n",
    "speech_signal = AudioSegment.from_file(speech_filepath, target_sr=sample_rate).samples\n",
    "\n",
    "# Move to device\n",
    "device = 'cuda' if accelerator == 'gpu' else 'cpu'\n",
    "enhancement_model = enhancement_model.to(device)\n",
    "\n",
    "# Process using the model\n",
    "noisy_tensor = torch.tensor(noisy_signal).reshape(1, 1, -1).to(device) # (batch, channel, time)\n",
    "with torch.no_grad():\n",
    "    output_tensor = enhancement_model(input_signal=noisy_tensor)\n",
    "output_signal = output_tensor[0][0].detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90f6c0d0",
   "metadata": {
    "id": "90f6c0d0"
   },
   "source": [
     "The signals can be plotted and signal metrics calculated for the given example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aec5c64",
   "metadata": {
    "id": "4aec5c64"
   },
   "outputs": [],
   "source": [
    "# Show noisy and clean signals\n",
    "show_metrics(signal=noisy_signal, reference=speech_signal, tag='Noisy signal', sample_rate=sample_rate)\n",
    "show_metrics(signal=output_signal, reference=speech_signal, tag='Output signal', sample_rate=sample_rate)\n",
    "\n",
    "# Show signals\n",
    "show_signal(speech_signal, tag='Speech signal')\n",
    "show_signal(noisy_signal, tag='Noisy signal')\n",
    "show_signal(output_signal, tag='Output signal')\n",
    "\n",
    "# Play audio\n",
    "print('Speech signal')\n",
    "ipd.display(ipd.Audio(speech_signal, rate=sample_rate))\n",
    "\n",
    "print('Noisy signal')\n",
    "ipd.display(ipd.Audio(noisy_signal, rate=sample_rate))\n",
    "\n",
    "print('Output signal')\n",
    "ipd.display(ipd.Audio(output_signal, rate=sample_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd6da1f0",
   "metadata": {
    "id": "cd6da1f0"
   },
   "source": [
    "## Next steps\n",
     "This is a simple tutorial which can serve as a starting point for prototyping and experimentation with audio-to-audio models.\n",
     "The processed audio output can be used, for example, as input to ASR or TTS.\n",
     "\n",
     "For more details about NeMo models and their applications in ASR and TTS, we recommend you check out the following tutorials next:\n",
    "\n",
    "* [NeMo fundamentals](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/00_NeMo_Primer.ipynb)\n",
    "* [NeMo models](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/01_NeMo_Models.ipynb)\n",
    "* [Speech Recognition](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_with_NeMo.ipynb)\n",
    "* [Speech Synthesis](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/tts/Inference_ModelSelect.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46a855e3",
   "metadata": {
    "id": "46a855e3"
   },
   "source": [
    "## References\n",
    "\n",
    "[1] M. Remane, R. R. Nalla and A. Dantrey, \"SEASR: Speech Enhancement for Automatic Speech Recognition Systems using Convolution Recurrent Neural Network with Residual Connections,\" 2024 IEEE 5th Women in Technology Conference (WINTECHCON), Bengaluru, India, 2024, pp. 1-5, doi: 10.1109/Wintechcon61988.2024.10837982\n",
    "\n",
     "[2] V. Panayotov, G. Chen, D. Povey, S. Khudanpur, \"LibriSpeech: An ASR corpus based on public domain audio books,\" ICASSP 2015\n",
    "\n",
     "[3] J. Thiemann, N. Ito, E. Vincent, \"DEMAND: a collection of multi-channel recordings of acoustic noise in diverse environments,\" ICA 2013\n",
    "\n",
    "[4] https://github.com/Lightning-AI/torchmetrics\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
